{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Trading compute for memory in PyTorch models using Checkpointing\n",
    "\n",
    "Given the fact that deep learning explorations are bottlenecked by device memory constraints, we implemented a gradient checkpointing technique described in [1] and using this technique, we can achieve 4x-10x larger models ranging across convolutional networks, recurrent networks and medical imaging applications. The implementations of memory optimized versions of models along with their baselines is available [here](https://github.com/prigoyal/pytorch_memonger/tree/master/models). The results of checkpointing are below:\n",
    "\n",
    "<img src=\"results.png\">\n",
    "\n",
    "As we see from the above results, for ResNet-1001 model, baseline implementation can only fit 15 images and for better packing and layers optimization, we use minibatch size of 8 per run and do 4 runs to simulate minibatch size 32 training. But in the checkpointed model, we can fit ~52 images in minibatch and hence use 32 images in each run. As evident from the results, per iteration <i> speed improves by ~10% </i> and the model accuracy will be better because of stable batch normalization layers. \n",
    "\n",
    "###  Gradient Checkpointing\n",
    "\n",
    "Gradient Checkpointing technique reorders the forward and reverse computation of the model so as to reduce the maximal length of the autograd tape. This is achieved by dividing the model or parts of the model in various segments and executing the segments without taping them in the forward pass i.e. their taping is delayed until the backward pass. This results in not storing the activations of those segments in memory. During the backward pass, since we need activations to compute the gradients, the forward pass is done again on the segment and the pass is taped. In this way, at any given time, the tape length is the maximum of tape length of various segments and not the sum.\n",
    "\n",
    "##  Using Checkpointing for PyTorch models\n",
    "\n",
    "Now in this tutorial, we will see how to use checkpointing on various models in PyTorch.\n",
    "\n",
    "<b>NOTE:</b> This tutorial needs PyTorch master branch which can be installed by following instructions [here](https://github.com/pytorch/pytorch#from-source)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# some standard imports\n",
    "import torch\n",
    "from torch.autograd import Variable"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Checkpointing sequential models\n",
    "\n",
    "First, let's create a simple Sequential model and checkpoint it. We can also verify that the checkpointing doesn't \n",
    "change the value of gradients or the activations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# import the checkpoint API\n",
    "from torch.utils.checkpoint import checkpoint_sequential\n",
    "import torch.nn as nn\n",
    "\n",
    "# create a simple Sequential model\n",
    "model = nn.Sequential(\n",
    "    nn.Linear(100, 50),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(50, 20),\n",
    "    nn.ReLU(),\n",
    "    nn.Linear(20, 5),\n",
    "    nn.ReLU()\n",
    ")\n",
    "\n",
    "# create the model inputs\n",
    "input_var = Variable(torch.randn(1, 100), requires_grad=True)\n",
    "\n",
    "# set the number of checkpoint segments\n",
    "segments = 2\n",
    "\n",
    "# get the modules in the model. These modules should be in the order\n",
    "# the model should be executed\n",
    "modules = [module for k, module in model._modules.items()]\n",
    "\n",
    "# now call the checkpoint API and get the output\n",
    "out = checkpoint_sequential(modules, segments, input_var)\n",
    "\n",
    "# run the backwards pass on the model. For backwards pass, for simplicity purpose, \n",
    "# we won't calculate the loss and rather backprop on out.sum()\n",
    "model.zero_grad()\n",
    "out.sum().backward()\n",
    "\n",
    "# now we save the output and parameter gradients that we will use for comparison purposes with\n",
    "# the non-checkpointed run.\n",
    "output_checkpointed = out.data.clone()\n",
    "grad_checkpointed = {}\n",
    "for name, param in model.named_parameters():\n",
    "    grad_checkpointed[name] = param.grad.data.clone()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have executed the checkpointed pass on the model, let's also run the non-checkpointed model and verify that the checkpoint API doesn't change the model outputs or the parameter gradients."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# non-checkpointed run of the model\n",
    "original = model\n",
    "\n",
    "# create a new variable using the same tensor data\n",
    "x = Variable(input_var.data, requires_grad=True)\n",
    "\n",
    "# get the model output and save it to prevent any modifications\n",
    "out = original(x)\n",
    "out_not_checkpointed = out.data.clone()\n",
    "\n",
    "# calculate the gradient now and save the parameter gradients values\n",
    "original.zero_grad()\n",
    "out.sum().backward()\n",
    "grad_not_checkpointed = {}\n",
    "for name, param in model.named_parameters():\n",
    "    grad_not_checkpointed[name] = param.grad.data.clone()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now that we have done the checkpointed and non-checkpointed pass of the model and saved the output and parameter gradients, let's compare their values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# compare the output and parameters gradients\n",
    "self.assertEqual(out_checkpointed, out_not_checkpointed)\n",
    "for name in grad_checkpointed:\n",
    "    self.assertEqual(grad_checkpointed[name], grad_not_checkpointed[name])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So, from this example, we can see that it's very easy to use checkpointing on Sequential models and that the checkpoint API doesn't alter any data. The Checkpoint API implementation is based on autograd and hence there is no need for explicitly specifying what the execution of backwards should look like. For more examples on checkpointing sequential models like ResNet and DenseNet, see the examples [here](https://github.com/prigoyal/pytorch_memonger/blob/master/models/optimized/resnet_new.py) and [here](https://github.com/prigoyal/pytorch_memonger/blob/master/models/optimized/densenet_new.py).\n",
    "\n",
    "<b>TIP: </b> As you might have notices in the above examples of Pre-activation ResNet/ DenseNet, the resnet block is directly appended to ```self.features``` and similarly Denselayer is directly appended to ```self.features```. Doing this allows for better granular checkpointing otherwise, wrapping Denselayers in DenseBlock (as is done in baseline), or resnetBlocks in ResNet layers would limit checkpointing to DenseBlock or ResNet layer which consumes more memory.\n",
    "\n",
    "<b> IMPORTANT NOTE</b>: The list of modules should strictly define the model execution and it shouldn't be arbitrarily arranged. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Checkpointing non-sequential models\n",
    "So far, we have seen how to do checkpointing on Sequential models like ResNet, DenseNet etc. for which we can easily extract the modules and call the ```checkpoint_sequential```. But there are other models like LSTM etc. which can't truly use ```checkpoint_sequential```. Now, let's see how we can use checkpointing in such cases. We will take an example for RNN model which is available [here](https://github.com/pytorch/examples/blob/master/word_language_model/model.py). For simplicity, the weight inits and hidden unit inits have been omitted from definition below."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# some standard imports\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "\n",
    "# define the model \n",
    "class RNNModel(nn.Module):\n",
    "    \"\"\"Container module with an encoder, a recurrent module, and a decoder.\"\"\"\n",
    "\n",
    "    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5):\n",
    "        super(RNNModel, self).__init__()\n",
    "        self.drop = nn.Dropout(dropout)\n",
    "        self.encoder = nn.Embedding(ntoken, ninp)\n",
    "        self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)\n",
    "        self.decoder = nn.Linear(nhid, ntoken)\n",
    "        self.nhid = nhid\n",
    "        self.nlayers = nlayers\n",
    "\n",
    "    def forward(self, input, hidden):\n",
    "        emb = self.drop(self.encoder(input))\n",
    "        output, hidden = self.rnn(emb, hidden)\n",
    "        output = self.drop(output)\n",
    "        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))\n",
    "        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In the above LSTM based model, we can optimize memory by checkpointing ```self.rnn()``` execution by segmenting the BPTT sequence length. In order to do this, we will use ```checkpoint()``` API call which requires passing a ```run_function``` that describes what should the checkpoint API run in forward pass for a given segment of the model. So let's define such a run function and modify the model as below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# some standard imports\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "\n",
    "# import the checkpoint API \n",
    "import torch.utils.checkpoint as checkpoint\n",
    "\n",
    "# modified model definition to use checkpointing\n",
    "class RNNModel(nn.Module):\n",
    "    \"\"\"Container module with an encoder, a recurrent module, and a decoder.\"\"\"\n",
    "\n",
    "    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5):\n",
    "        super(RNNModel, self).__init__()\n",
    "        self.drop = nn.Dropout(dropout)\n",
    "        self.encoder = nn.Embedding(ntoken, ninp)\n",
    "        self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)\n",
    "        self.decoder = nn.Linear(nhid, ntoken)\n",
    "        self.nhid = nhid\n",
    "        self.nlayers = nlayers\n",
    "\n",
    "    # define a run_function that returns a function which takes the inputs variables as \n",
    "    # arguments, unpacks those inputs, perform the computation and returns the results. \n",
    "    # Note that the start, end are not inuts to custom_forward function below.    \n",
    "    def run_function(self, start, end):\n",
    "        def custom_forward(*inputs):\n",
    "            output, hidden = self.rnn(\n",
    "                inputs[0][start:(end+1)], (inputs[1], inputs[2])\n",
    "            )\n",
    "            return output, hidden[0], hidden[1]\n",
    "        return custom_forward\n",
    "    \n",
    "    def forward(self, input, hidden, segments):\n",
    "        emb = self.drop(self.encoder(input))\n",
    "        # checkpoint self.rnn() computation\n",
    "        output = []\n",
    "        segment_size = len(modules) // segments\n",
    "        for start in range(0, segment_size * (segments - 1), segment_size):\n",
    "            end = start + segment_size - 1\n",
    "            # Note that if there are multiple inputs, we pass them as as is without \n",
    "            # wrapping in a tuple etc.\n",
    "            out = checkpoint.checkpoint(\n",
    "                self.run_function(start, end), emb, hidden[0], hidden[1]\n",
    "            )\n",
    "            output.append(out[0])\n",
    "            hidden = (out[1], out[2])\n",
    "        out = checkpoint.checkpoint(\n",
    "            self.run_function(end + 1, len(modules) - 1), emb, hidden[0], hidden[1]\n",
    "        )\n",
    "        output.append(out[0])\n",
    "        hidden = (out[1], out[2])\n",
    "        \n",
    "        output = torch.cat(output, 0)\n",
    "        hidden = (out[1], out[2])\n",
    "        output = self.drop(output)\n",
    "        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))\n",
    "        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In summary, for checkpointing any part of the model, you can define a ```run_function``` which describes what should the checkpoint API run in the forward pass. Please see the above example and comments on the code.\n",
    "\n",
    "<b> NOTE: </b> In case of checkpointing, if all the inputs don't require grad but the outputs do, then if the inputs are passed as is, the output of Checkpoint will be variable which don't require grad and autograd tape will break there. To get around, you can pass a dummy input which requires grad but isn't necessarily used in computation.\n",
    "\n",
    "### Checkpointing any module\n",
    "\n",
    "If you want to checkpoint only a certain module in your network, it's very easy to checkpoint it. Let's create a dummy model comprising of many layers and we want to checkpoint one module in it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# some standard imports\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable, Function\n",
    "import torch.utils.checkpoint as checkpoint\n",
    "\n",
    "class ConvBNReLU(nn.Module):\n",
    "    \n",
    "    def __init__(self, in_planes, out_planes):\n",
    "        \n",
    "        super(ConvBNReLU, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, bias=False)\n",
    "        self.bn1 = nn.BatchNorm2d(out_planes)\n",
    "        self.relu1 = nn.ReLU(inplace=True)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        out = self.relu1(self.bn1(self.conv1(x)))\n",
    "        return out\n",
    "\n",
    "class DummyNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(DummyNet, self).__init__()\n",
    "        self.features = nn.Sequential(OrderedDict([\n",
    "            ('conv1', nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)),\n",
    "            ('bn1', nn.BatchNorm2d(16)),\n",
    "            ('relu1', nn.ReLU(inplace=True)),\n",
    "        ]))\n",
    "\n",
    "        # The module that we want to checkpoint\n",
    "        self.module = ConvBNReLU(16, 64) \n",
    "        \n",
    "        self.final_module = ConvBNReLU(64, 64)\n",
    "    \n",
    "    def custom(self, module):\n",
    "        def custom_forward(*inputs):\n",
    "            inputs = module(inputs[0])\n",
    "            return inputs\n",
    "        return custom_forward\n",
    "    \n",
    "    def forward(self, x):\n",
    "        out = self.features(x)\n",
    "        out = checkpoint.checkpoint(self.custom(self.module), out)\n",
    "        out = self.final_module(out)\n",
    "        return out"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Handling a few special layers in checkpointing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "While using checkpointing, here are a few useful things that should be followed in order to avoid any issues with the output and grad values. \n",
    "\n",
    "### Dropout Layer\n",
    "The checkpointing API runs the forward pass twice on the checkpointed segment of the model: once during the forward pass which is untaped, and second times during the backward pass where the forward pass is taped. Since dropout drops values arbitrarily, <b> avoid </b> having dropout in the checkpointed model part otherwise it will result in incorrect output. Instead, split the checkpointing into two parts: before and after the dropout layer\n",
    "\n",
    "### Batch Normalization layer\n",
    "Batch normalization layer maintains the running mean and variance stats depending on the current minibatch and everytime a forward pass is run, the stats are updated based on the ```momentum``` value. In checkpointing, running the forward pass twice on a model segment in the same iteration will result in updating mean and stats value. In order to avoid this, use the ```new_momentum = sqrt(momentum)``` as the momentum value. The mathematical justification is as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "original_rm, original_nm = running_mean, new_mean\n",
    "\n",
    "# Originally, we have a running_mean to start from\n",
    "momemtum = 0.9\n",
    "# running mean after the first forward pass. We want to match this \n",
    "# when we use checkpointing\n",
    "running_mean1 = original_nm * 0.1 + original_rm * 0.9\n",
    "\n",
    "# With checkpointing, we use momentum = sqrt(0.9) and make two passes\n",
    "momentum = sqrt(0.9)\n",
    "runing_mean = original_nm * (1 - sqrt(0.9)) + original_rm * sqrt(0.9)      # Eq (1) \n",
    "# Note that in the second forward pass, since the inputs are same, the \n",
    "# new_mean value doesn't change\n",
    "running_mean2 = original_nm * (1 - sqrt(0.9)) + running_mean * sqrt(0.9)   # Eq (2) \n",
    "\n",
    "#Substitue value of running_mean from Eq(1) to Eq(2)\n",
    "assert running_mean2 == running_mean1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Decoder layer in Word language model\n",
    "In the Pytorch example for [word language model](https://github.com/pytorch/examples/blob/master/word_language_model/model.py#L20), the decoder is a single fully connected layer that operates on the entire sequence length and batch size. This layer is followed by Cross Entropy loss calculation which first calculates the softmax on the output of decoder. Now, as the sequence length is increases, the decoder activation size and softmax output size increases as well and is very large in size (almost entire memory consumption is because of decoder layer). In order to get around, we can split the decoder layer in segments, calculate the loss on each segment and do the loss backward on each segment. This saves tremendous amount of memory. For reference see [this](https://github.com/prigoyal/pytorch_memonger/blob/master/models/optimized/word_language_model_new.py)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "[1]. Siskind, Jeffrey Mark, and Barak A. Pearlmutter. \"Divide-and-Conquer Checkpointing for Arbitrary Programs with No User Annotation.\" arXiv preprint arXiv:1708.06799 (2017).\n",
    "\n",
    "[2]. T. Chen, B. Xu, C. Zhang, and C. Guestrin. Training deep nets with sublinear memory cost.\n",
    "arXiv preprint arXiv:1604.06174, 2016\n",
    "\n",
    "[3]. Gomez, A. N., Ren, M., Urtasun, R., & Grosse, R. B. (2017). The reversible residual network: Backpropagation without storing activations. In Advances in Neural Information Processing Systems (pp. 2211-2221)."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3rc1+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
